/*
 * Copyright 2004-2011 H2 Group.
 * Copyright 2011 James Moger.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.iciql;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

import javax.sql.DataSource;

import com.iciql.DbUpgrader.DefaultDbUpgrader;
import com.iciql.Iciql.IQTable;
import com.iciql.Iciql.IQVersion;
import com.iciql.util.JdbcUtils;
import com.iciql.util.IciqlLogger;
import com.iciql.util.StringUtils;
import com.iciql.util.Utils;
import com.iciql.util.WeakIdentityHashMap;

/**
 * A connection to a database, wrapped with iciql's object mapping.
 * <p>
 * A {@code Db} wraps one JDBC {@link Connection}, chooses a
 * {@link SQLDialect} for it when it is opened, and caches one
 * {@link TableDefinition} per model class. Obtain instances through the
 * static {@code open(...)} factories and release them with {@link #close()}.
 */
public class Db {

	/**
	 * This map holds unique tokens that are generated by functions such as
	 * Function.sum(..) in "db.from(p).select(Function.sum(p.unitPrice))". It
	 * doesn't actually hold column tokens, as those are bound to the query
	 * itself.
	 */
	private static final Map<Object, Token> TOKENS;

	/**
	 * Registered dialect classes, keyed by either a connection class name or a
	 * database product name (see {@link #registerDialect(String, Class)}).
	 */
	private static final Map<String, Class<? extends SQLDialect>> DIALECTS;

	private final Connection conn;

	/** Table definitions already built for model classes; see {@link #define(Class)}. */
	private final Map<Class<?>, TableDefinition<?>> classMap = Collections
			.synchronizedMap(new HashMap<Class<?>, TableDefinition<?>>());

	private final SQLDialect dialect;

	/** Never null: initialized here and {@link #setDbUpgrader} rejects null. */
	private DbUpgrader dbUpgrader = new DefaultDbUpgrader();

	/**
	 * Upgrader classes and model classes whose version has already been
	 * checked against the database through this instance.
	 */
	private final Set<Class<?>> upgradeChecked = Collections.synchronizedSet(new HashSet<Class<?>>());

	static {
		TOKENS = Collections.synchronizedMap(new WeakIdentityHashMap<Object, Token>());
		DIALECTS = Collections.synchronizedMap(new HashMap<String, Class<? extends SQLDialect>>());
		// can register by...
		// 1. Connection class name
		// 2. DatabaseMetaData.getDatabaseProductName()
		DIALECTS.put("Apache Derby", SQLDialectDerby.class);
		DIALECTS.put("H2", SQLDialectH2.class);
		DIALECTS.put("HSQL Database Engine", SQLDialectHSQL.class);
		DIALECTS.put("MySQL", SQLDialectMySQL.class);
		DIALECTS.put("PostgreSQL", SQLDialectPostgreSQL.class);
	}

	/**
	 * Wraps a connection and selects/configures the dialect for it.
	 * 
	 * @param conn
	 *            the JDBC connection to wrap
	 * @throws IciqlException
	 *             if the database metadata cannot be read
	 */
	private Db(Connection conn) {
		this.conn = conn;
		String databaseName;
		DatabaseMetaData data;
		try {
			data = conn.getMetaData();
			databaseName = data.getDatabaseProductName();
		} catch (SQLException s) {
			throw new IciqlException(s, "failed to retrieve database metadata!");
		}
		dialect = getDialect(databaseName, conn.getClass().getName());
		dialect.configureDialect(databaseName, data);
	}

	/**
	 * Register a new/custom dialect class. You can use this method to replace
	 * any existing dialect or to add a new one.
	 * 
	 * @param token
	 *            the fully qualified name of the connection class or the
	 *            expected result of DatabaseMetaData.getDatabaseProductName()
	 * @param dialectClass
	 *            the dialect class to register
	 */
	public static void registerDialect(String token, Class<? extends SQLDialect> dialectClass) {
		DIALECTS.put(token, dialectClass);
	}

	/**
	 * Instantiates the dialect for a connection. A registration under the
	 * connection class name wins over one under the database product name;
	 * when neither is registered, {@link SQLDialectDefault} is used.
	 * 
	 * @param databaseName
	 *            the database product name reported by the driver
	 * @param className
	 *            the fully qualified connection class name
	 * @return a new dialect instance
	 */
	SQLDialect getDialect(String databaseName, String className) {
		// single get() per key: containsKey()+get() is not atomic even on a
		// synchronized map
		Class<? extends SQLDialect> dialectClass = DIALECTS.get(className);
		if (dialectClass == null) {
			// not registered by connection class name, try database name
			dialectClass = DIALECTS.get(databaseName);
		}
		if (dialectClass == null) {
			// did not find a match, use default
			dialectClass = SQLDialectDefault.class;
		}
		return instance(dialectClass);
	}

	static <X> X registerToken(X x, Token token) {
		TOKENS.put(x, token);
		return x;
	}

	static Token getToken(Object x) {
		return TOKENS.get(x);
	}

	/**
	 * Creates an instance of clazz through its no-argument constructor.
	 * 
	 * @throws IciqlException
	 *             wrapping any failure to instantiate
	 */
	static <T> T instance(Class<T> clazz) {
		try {
			return clazz.newInstance();
		} catch (Exception e) {
			throw new IciqlException(e);
		}
	}

	/**
	 * Opens a database from a JDBC url without credentials.
	 * 
	 * @param url
	 *            the JDBC url
	 * @return the database instance
	 * @throws IciqlException
	 *             if the connection cannot be established
	 */
	public static Db open(String url) {
		try {
			Connection conn = JdbcUtils.getConnection(null, url, null, null);
			return new Db(conn);
		} catch (SQLException e) {
			throw new IciqlException(e);
		}
	}

	/**
	 * Opens a database from a JDBC url with a user name and password.
	 * 
	 * @param url
	 *            the JDBC url
	 * @param user
	 *            the user name
	 * @param password
	 *            the password
	 * @return the database instance
	 * @throws IciqlException
	 *             if the connection cannot be established
	 */
	public static Db open(String url, String user, String password) {
		try {
			Connection conn = JdbcUtils.getConnection(null, url, user, password);
			return new Db(conn);
		} catch (SQLException e) {
			throw new IciqlException(e);
		}
	}

	/**
	 * Create a new database instance using a data source. This method is fast,
	 * so that you can always call open() / close() on usage.
	 * 
	 * @param ds
	 *            the data source
	 * @return the database instance.
	 */
	public static Db open(DataSource ds) {
		try {
			return new Db(ds.getConnection());
		} catch (SQLException e) {
			throw new IciqlException(e);
		}
	}

	/**
	 * Wraps an existing connection. {@link #close()} will close it.
	 * 
	 * @param conn
	 *            the JDBC connection
	 * @return the database instance
	 */
	public static Db open(Connection conn) {
		return new Db(conn);
	}

	/**
	 * Opens a database from a JDBC url with a user name and a char[] password.
	 * 
	 * @param url
	 *            the JDBC url
	 * @param user
	 *            the user name, may be null
	 * @param password
	 *            the password, may be null
	 * @return the database instance
	 * @throws IciqlException
	 *             if the connection cannot be established
	 */
	public static Db open(String url, String user, char[] password) {
		try {
			Properties prop = new Properties();
			// Properties is a Hashtable and rejects null values, so only the
			// supplied credentials are set
			if (user != null) {
				prop.setProperty("user", user);
			}
			if (password != null) {
				// the char[] is stored as-is rather than converted to a String.
				// NOTE(review): presumably for drivers (e.g. H2) that accept a
				// char[] password property — confirm
				prop.put("password", password);
			}
			Connection conn = JdbcUtils.getConnection(null, url, prop);
			return new Db(conn);
		} catch (SQLException e) {
			throw new IciqlException(e);
		}
	}

	/**
	 * Inserts a model object, creating its table first if required.
	 * 
	 * @param t
	 *            the object to insert
	 * @throws IciqlException
	 *             if the insert affected no rows
	 */
	public <T> void insert(T t) {
		Class<?> clazz = t.getClass();
		long rc = define(clazz).createIfRequired(this).insert(this, t, false);
		if (rc == 0) {
			throw new IciqlException("Failed to insert {0}.  Affected rowcount == 0.", t);
		}
	}

	/**
	 * Inserts a model object, creating its table first if required, and
	 * returns the value produced by the table definition's insert with
	 * generated keys requested (presumably the generated key).
	 * 
	 * @param t
	 *            the object to insert
	 * @return the value returned by the insert
	 */
	public <T> long insertAndGetKey(T t) {
		Class<?> clazz = t.getClass();
		return define(clazz).createIfRequired(this).insert(this, t, true);
	}

	/**
	 * Merge INSERTS if the record does not exist or UPDATES the record if it
	 * does exist. Not all databases support MERGE and the syntax varies with
	 * the database.
	 * 
	 * If the database does not support a MERGE syntax the dialect can try to
	 * simulate a merge by implementing:
	 * <p>
	 * INSERT INTO foo... (SELECT ?,... FROM foo WHERE pk=? HAVING count(*)=0)
	 * <p>
	 * iciql will check the affected row count returned by the internal merge
	 * method and if the affected row count = 0, it will issue an update.
	 * <p>
	 * See the Derby dialect for an implementation of this technique.
	 * <p>
	 * If the dialect does not support merge an IciqlException will be thrown.
	 * 
	 * @param t
	 *            the object to merge
	 * @throws IciqlException
	 *             if neither the merge nor the follow-up update affected a row
	 */
	public <T> void merge(T t) {
		Class<?> clazz = t.getClass();
		TableDefinition<?> def = define(clazz).createIfRequired(this);
		int rc = def.merge(this, t);
		if (rc == 0) {
			rc = def.update(this, t);
		}
		if (rc == 0) {
			throw new IciqlException("merge failed");
		}
	}

	/**
	 * Updates a model object, creating its table first if required.
	 * 
	 * @param t
	 *            the object to update
	 * @return the count returned by the table definition's update
	 */
	public <T> int update(T t) {
		Class<?> clazz = t.getClass();
		return define(clazz).createIfRequired(this).update(this, t);
	}

	/**
	 * Deletes a model object, creating its table first if required.
	 * 
	 * @param t
	 *            the object to delete
	 * @return the count returned by the table definition's delete
	 */
	public <T> int delete(T t) {
		Class<?> clazz = t.getClass();
		return define(clazz).createIfRequired(this).delete(this, t);
	}

	/**
	 * Starts a query using alias as the model instance whose fields name the
	 * columns. The alias's table is created first if required.
	 * 
	 * @param alias
	 *            the model instance to query from
	 * @return a new query
	 */
	public <T extends Object> Query<T> from(T alias) {
		Class<?> clazz = alias.getClass();
		define(clazz).createIfRequired(this);
		return Query.from(this, alias);
	}

	/**
	 * Drops the table of a model class and evicts its cached definition. An
	 * {@link IciqlException} with code
	 * {@link IciqlException#CODE_OBJECT_NOT_FOUND} is ignored (result 0).
	 * 
	 * @param modelClass
	 *            the model class whose table is dropped
	 * @return the update count of the drop statement, or 0 if it was ignored
	 */
	@SuppressWarnings("unchecked")
	public <T> int dropTable(Class<? extends T> modelClass) {
		TableDefinition<T> def = (TableDefinition<T>) define(modelClass);
		SQLStatement stat = new SQLStatement(this);
		getDialect().prepareDropTable(stat, def);
		IciqlLogger.drop(stat.getSQL());
		int rc = 0;
		try {
			rc = stat.executeUpdate();
		} catch (IciqlException e) {
			if (e.getIciqlCode() != IciqlException.CODE_OBJECT_NOT_FOUND) {
				throw e;
			}
		}
		// remove this model class from the table definition cache
		classMap.remove(modelClass);
		return rc;
	}

	/**
	 * Reads every remaining row of rs into new instances of modelClass. The
	 * result set is not closed by this method.
	 * 
	 * @param modelClass
	 *            the model class to instantiate per row
	 * @param rs
	 *            the result set to read
	 * @return the mapped objects, in row order
	 */
	@SuppressWarnings("unchecked")
	public <T> List<T> buildObjects(Class<? extends T> modelClass, ResultSet rs) {
		List<T> result = new ArrayList<T>();
		TableDefinition<T> def = (TableDefinition<T>) define(modelClass);
		try {
			while (rs.next()) {
				T item = Utils.newObject(modelClass);
				def.readRow(item, rs);
				result.add(item);
			}
		} catch (SQLException e) {
			throw new IciqlException(e);
		}
		return result;
	}

	/**
	 * Checks (once per upgrader class) the database version declared by the
	 * upgrader's {@link IQVersion} annotation against the DbVersion entry with
	 * an empty schema and table name, and runs the upgrader when the database
	 * is older. An upgrader without a positive version is ignored.
	 * 
	 * @return this instance
	 */
	Db upgradeDb() {
		Class<?> upgraderClass = dbUpgrader.getClass();
		if (upgradeChecked.contains(upgraderClass)) {
			return this;
		}
		// flag as checked immediately because calls are nested.
		upgradeChecked.add(upgraderClass);

		// getAnnotation() returns null when the annotation is absent
		IQVersion model = upgraderClass.getAnnotation(IQVersion.class);
		if (model == null || model.value() <= 0) {
			// no database version declared: nothing to track
			return this;
		}

		DbVersion v = new DbVersion();
		// (SCHEMA="" && TABLE="") == DATABASE
		DbVersion dbVersion = from(v).where(v.schemaName).is("").and(v.tableName).is("").selectFirst();
		if (dbVersion == null) {
			// database has no version registration, but model specifies
			// version: upgrade from 0 and register the version on success.
			DbVersion newDb = new DbVersion(model.value());
			boolean success = dbUpgrader.upgradeDatabase(this, 0, newDb.version);
			if (success) {
				insert(newDb);
			}
		} else if (model.value() > dbVersion.version) {
			// database is an older version than the model
			boolean success = dbUpgrader.upgradeDatabase(this, dbVersion.version, model.value());
			if (success) {
				dbVersion.version = model.value();
				update(dbVersion);
			}
		}
		return this;
	}

	/**
	 * Checks (once per model class) the table version of model against its
	 * DbVersion entry. A missing entry is registered; an older entry triggers
	 * the upgrader and is bumped on success. Models with a non-positive table
	 * version are ignored.
	 * 
	 * @param model
	 *            the table definition to check
	 */
	<T> void upgradeTable(TableDefinition<T> model) {
		if (upgradeChecked.contains(model.getModelClass())) {
			return;
		}
		// flag is checked immediately because calls are nested
		upgradeChecked.add(model.getModelClass());

		if (model.tableVersion <= 0) {
			// table is not using iciql version tracking
			return;
		}

		DbVersion v = new DbVersion();
		String schema = StringUtils.isNullOrEmpty(model.schemaName) ? "" : model.schemaName;
		DbVersion dbVersion = from(v).where(v.schemaName).is(schema).and(v.tableName)
				.is(model.tableName).selectFirst();
		if (dbVersion == null) {
			// table has no version registration, but model specifies
			// version: insert DbVersion entry
			DbVersion newTable = new DbVersion(model.tableVersion);
			newTable.schemaName = schema;
			newTable.tableName = model.tableName;
			insert(newTable);
		} else if (model.tableVersion > dbVersion.version) {
			// table is an older version than model
			boolean success = dbUpgrader.upgradeTable(this, schema, model.tableName, dbVersion.version,
					model.tableVersion);
			if (success) {
				dbVersion.version = model.tableVersion;
				update(dbVersion);
			}
		}
	}

	/**
	 * Returns the cached table definition for clazz, building and caching it
	 * on first use. The database version check runs before a definition is
	 * built.
	 * 
	 * @param clazz
	 *            the model class
	 * @return the table definition
	 */
	<T> TableDefinition<T> define(Class<T> clazz) {
		TableDefinition<T> def = getTableDefinition(clazz);
		if (def == null) {
			upgradeDb();
			def = new TableDefinition<T>(clazz);
			def.mapFields();
			classMap.put(clazz, def);
			if (Iciql.class.isAssignableFrom(clazz)) {
				T t = instance(clazz);
				Iciql table = (Iciql) t;
				Define.define(def, table);
			} else if (clazz.isAnnotationPresent(IQTable.class)) {
				// annotated classes skip the Define().define() static
				// initializer
				T t = instance(clazz);
				def.mapObject(t);
			}
		}
		return def;
	}

	/**
	 * Replaces the database upgrader and clears all recorded version checks so
	 * they run again.
	 * 
	 * @param upgrader
	 *            the upgrader, which must be annotated with {@link IQVersion}
	 * @throws IciqlException
	 *             if upgrader is null or is not annotated with IQVersion
	 */
	public synchronized void setDbUpgrader(DbUpgrader upgrader) {
		if (upgrader == null) {
			throw new IciqlException("DbUpgrader must not be null");
		}
		if (!upgrader.getClass().isAnnotationPresent(IQVersion.class)) {
			throw new IciqlException("DbUpgrader must be annotated with " + IQVersion.class.getSimpleName());
		}
		this.dbUpgrader = upgrader;
		upgradeChecked.clear();
	}

	/** @return the dialect selected for this connection */
	public SQLDialect getDialect() {
		return dialect;
	}

	/** @return the wrapped JDBC connection */
	public Connection getConnection() {
		return conn;
	}

	/**
	 * Closes the wrapped connection.
	 * 
	 * @throws IciqlException
	 *             wrapping any failure to close
	 */
	public void close() {
		try {
			conn.close();
		} catch (Exception e) {
			throw new IciqlException(e);
		}
	}

	/**
	 * Creates a {@link TestCondition} for x.
	 * 
	 * @param x
	 *            the value to test
	 * @return a new test condition
	 */
	public <A> TestCondition<A> test(A x) {
		return new TestCondition<A>(x);
	}

	/** Inserts each element of list via {@link #insert(Object)}. */
	public <T> void insertAll(List<T> list) {
		for (T t : list) {
			insert(t);
		}
	}

	/**
	 * Inserts each element of list via {@link #insertAndGetKey(Object)}.
	 * 
	 * @return the values returned per element, in list order
	 */
	public <T> List<Long> insertAllAndGetKeys(List<T> list) {
		List<Long> identities = new ArrayList<Long>();
		for (T t : list) {
			identities.add(insertAndGetKey(t));
		}
		return identities;
	}

	/** Updates each element of list via {@link #update(Object)}. */
	public <T> void updateAll(List<T> list) {
		for (T t : list) {
			update(t);
		}
	}

	/** Deletes each element of list via {@link #delete(Object)}. */
	public <T> void deleteAll(List<T> list) {
		for (T t : list) {
			delete(t);
		}
	}

	/**
	 * Prepares sql on the wrapped connection after checking it for unmapped
	 * fields.
	 * 
	 * @param sql
	 *            the SQL statement
	 * @param returnGeneratedKeys
	 *            whether to request generated keys
	 * @return the prepared statement; the caller must close it
	 */
	PreparedStatement prepare(String sql, boolean returnGeneratedKeys) {
		IciqlException.checkUnmappedField(sql);
		try {
			if (returnGeneratedKeys) {
				return conn.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS);
			}
			return conn.prepareStatement(sql);
		} catch (SQLException e) {
			throw IciqlException.fromSQL(sql, e);
		}
	}

	/** @return the cached definition for clazz, or null if none was built yet */
	@SuppressWarnings("unchecked")
	<T> TableDefinition<T> getTableDefinition(Class<T> clazz) {
		return (TableDefinition<T>) classMap.get(clazz);
	}

	/**
	 * Runs a query, using a plain Statement when args is null or empty and a
	 * PreparedStatement bound to args otherwise. If anything fails before the
	 * ResultSet is returned, the statement is closed; on success the caller
	 * owns the ResultSet and its statement.
	 */
	private ResultSet openResultSet(String sql, Object[] args) throws SQLException {
		Statement stat = null;
		boolean success = false;
		try {
			ResultSet rs;
			if (args == null || args.length == 0) {
				stat = conn.createStatement();
				rs = stat.executeQuery(sql);
			} else {
				PreparedStatement prep = conn.prepareStatement(sql);
				stat = prep;
				int i = 1;
				for (Object arg : args) {
					prep.setObject(i++, arg);
				}
				rs = prep.executeQuery();
			}
			success = true;
			return rs;
		} finally {
			if (!success) {
				// avoid leaking the statement when execution fails
				JdbcUtils.closeSilently(stat);
			}
		}
	}

	/**
	 * Run a SQL query directly against the database.
	 * 
	 * Be sure to close the ResultSet with
	 * 
	 * <pre>
	 * JdbcUtils.closeSilently(rs, true);
	 * </pre>
	 * 
	 * @param sql
	 *            the SQL statement
	 * @param args
	 *            optional object arguments for x=? tokens in query
	 * @return the result set
	 */
	public ResultSet executeQuery(String sql, Object... args) {
		try {
			return openResultSet(sql, args);
		} catch (SQLException e) {
			throw new IciqlException(e);
		}
	}

	/**
	 * Run a SQL query directly against the database and map the results to the
	 * model class. The result set and its statement are closed before
	 * returning.
	 * 
	 * @param modelClass
	 *            the model class to bind the query ResultSet rows into.
	 * @param sql
	 *            the SQL statement
	 * @param args
	 *            optional object arguments for x=? tokens in query
	 * @return the mapped objects, one per row
	 */
	public <T> List<T> executeQuery(Class<? extends T> modelClass, String sql, Object... args) {
		ResultSet rs = null;
		try {
			rs = openResultSet(sql, args);
			return buildObjects(modelClass, rs);
		} catch (SQLException e) {
			throw new IciqlException(e);
		} finally {
			JdbcUtils.closeSilently(rs, true);
		}
	}

	/**
	 * Run a SQL statement directly against the database.
	 * 
	 * @param sql
	 *            the SQL statement
	 * @return the update count
	 */
	public int executeUpdate(String sql) {
		Statement stat = null;
		try {
			stat = conn.createStatement();
			return stat.executeUpdate(sql);
		} catch (SQLException e) {
			throw new IciqlException(e);
		} finally {
			JdbcUtils.closeSilently(stat);
		}
	}
}
